#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<algorithm>
#include<cmath>
#include<cstring>
using namespace std;
// Modulus the answer is reported under.
const long long MOD = 998244353;

// Returns the quantity the original per-index loop accumulated, modulo MOD.
//
// With a = max(n, m) and b = min(n, m) >= 1 that loop summed
//     sum_{i=1}^{b-1} [(a - i + 1) + (b - i)]  +  (a - b)
//   = (b - 1)(a + 1) + (a - b)
//   = a*b - 1  =  n*m - 1,
// so it is evaluated directly in O(1) rather than O(min(n, m)) iterations.
// When min(n, m) <= 0 the loop body never ran and the result was 0.
//
// The old version kept the running sum in an `int` and only reduced each
// term, so the sum itself overflowed (signed overflow is undefined
// behavior) once n*m exceeded INT_MAX. Here both operands are reduced
// first, so (n % MOD) * (m % MOD) < 2^60 fits comfortably in long long.
static long long solve(long long n, long long m)
{
	if (std::min(n, m) < 1)
		return 0;
	// +MOD before the final reduction keeps the result non-negative when
	// n*m is a multiple of MOD (answer is then MOD - 1).
	return ((n % MOD) * (m % MOD) % MOD + MOD - 1) % MOD;
}

int main()
{
	freopen("bpmp.in", "r", stdin);
	freopen("bpmp.out", "w", stdout);
	// If the input is missing or short, the unread dimension stays 0 and
	// the answer is 0 -- the same output the original produced in that case.
	long long n = 0, m = 0;
	scanf("%lld%lld", &n, &m);
	printf("%lld", solve(n, m));
	return 0;
}
